import torch
import torch.nn as nn
import numpy as np
from deepset import *
from cdm import *
import numpy as np
import copy
import pickle
import csv
from collections import defaultdict
from IPython import embed
from core_shapley import *

# How does an NBA all-star team compare to an average team?

# Player-name -> index lookup (values are index strings; see get_idx/get_names).
# NOTE: pickle.load can execute arbitrary code embedded in the file, so only
# load this data file from a trusted source.
with open("data/starters_player_idx.p", "rb") as _ch_id_file:
    ch_id = pickle.load(_ch_id_file)
# Reverse lookup: stored index value -> player name.
id_ch = {v: k for k, v in ch_id.items()}

NUM_PLAYERS = len(ch_id)

def get_idx(names):
    """Translate player names into an array of integer indices.

    Each name is looked up in the module-level ``ch_id`` mapping and the
    stored value is converted with ``int``.

    Args:
        names: iterable of player-name keys present in ``ch_id``.

    Returns:
        ``np.ndarray`` holding one integer index per name, in input order.

    Raises:
        KeyError: if a name is not a key of ``ch_id``.
    """
    indices = []
    for player in names:
        indices.append(int(ch_id[player]))
    return np.array(indices)
def get_names(idxs):
    """Translate player indices back into an array of player names.

    Each index is converted with ``str`` before being looked up in the
    module-level ``id_ch`` mapping.

    Args:
        idxs: iterable of indices whose string form is a key of ``id_ch``.

    Returns:
        ``np.ndarray`` holding one name per index, in input order.

    Raises:
        KeyError: if ``str(index)`` is not a key of ``id_ch``.
    """
    players = []
    for index in idxs:
        key = str(index)
        players.append(id_ch[key])
    return np.array(players)

# Per-player list of salary entries, one per CSV row, each divided by 1e6
# (presumably converting dollars to millions -- confirm against the data file).
salary = defaultdict(list)
with open("data/nba_salaries_13_19.csv", "r") as csvfile:
    for row in csv.DictReader(csvfile):
        salary[row["player"]].append(float(row["salary"]) / 1e6)

# player_team: team string from the last qualifying row seen for each player.
# player_stats: one [PER, WS, VORP] array per qualifying row for each player.
player_team = {}
player_stats = defaultdict(list)
with open("data/Seasons_Stats.csv", "r") as csvfile:
    for row in csv.DictReader(csvfile):
        # Drop rows where any field needed below is an empty string.
        field_lengths = [len(row[val]) for val in ["Year", "Player", "PER", "WS", "VORP"]]
        if 0 in field_lengths:
            continue
        # Only rows with Year >= 2012 are kept.
        if int(row["Year"]) < 2012:
            continue

        player_name = row["Player"]
        season_line = np.array([float(row["PER"]), float(row["WS"]), float(row["VORP"])])
        team = str(row["Tm"])

        player_stats[player_name].append(season_line)
        player_team[player_name] = team

# Interactive sanity check: the averaged stat line is computed but discarded
# when this file runs as a script.
np.average(player_stats["Stephen Curry"], axis=0)


#########################################################################

'''
Load Models
'''

def fhoi_compute_pair_score(model, idxs, absl=False):
    """Sum the pairwise weights of ``model`` over every ordered pair in ``idxs``.

    The weight matrix is read from ``model.weight.weight``. Every entry
    ``mat[i, j]`` with both ``i`` and ``j`` taken from ``idxs`` is added,
    which includes the ``i == j`` terms and both orderings of each pair.

    Args:
        model: object exposing a 2-D tensor at ``model.weight.weight``.
        idxs: re-iterable collection of integer indices into that matrix.
        absl: when true, absolute values of the entries are accumulated.

    Returns:
        The accumulated score; ``0.0`` when ``idxs`` is empty.
    """
    mat = model.weight.weight.detach().numpy()
    # Same row-major accumulation order as a nested loop, starting from 0.0.
    if absl:
        return sum((np.abs(mat[i, j]) for i in idxs for j in idxs), 0.0)
    return sum((mat[i, j] for i in idxs for j in idxs), 0.0)
# Select the FHoi pairwise scorer and restore its saved weights.
# NOTE(review): `score_fn` and `model` are rebound by the DeepSet, LR and CDM
# sections further down, so this binding is overwritten when the whole file runs.
score_fn = fhoi_compute_pair_score

filepath = "example_models/nba_fhoi.pth"
model = FHoi(NUM_PLAYERS)
model.load_state_dict(torch.load(filepath))

def deepset_compute_score(model, idxs, absl=False):
    """Score a group of players with a DeepSet-style model.

    The score is the sum of two parts:
      * a set-level term: embeddings of ``idxs`` are summed over the first
        axis, then passed through ``linear1`` -> ``ReLU`` -> ``linear2``;
      * a per-player term: ``model.bias`` looked up for ``idxs`` and summed.

    Args:
        model: object exposing ``embedding``, ``linear1``, ``ReLU``,
            ``linear2`` and ``bias`` callables.
        idxs: integer indices accepted by ``torch.LongTensor``.
        absl: when true, the absolute value of each of the two terms is
            taken before they are added.

    Returns:
        The per-player term plus element 0 of the set-level term, as a NumPy
        value.
    """
    members = torch.LongTensor(idxs)

    pooled = model.embedding(members).sum(0)
    hidden = model.ReLU(model.linear1(pooled))
    set_term = model.linear2(hidden)

    player_term = model.bias(members).sum()

    if absl:
        set_term, player_term = torch.abs(set_term), torch.abs(player_term)

    return player_term.detach().numpy() + set_term.detach().numpy()[0]
# Select the DeepSet scorer and restore its saved weights.
# NOTE(review): `score_fn` and `model` are rebound again by the LR and CDM
# sections further down, so this binding is overwritten when the whole file runs.
score_fn = deepset_compute_score

filepath = "example_models/nba_deepset.pth"
model = DeepSet(NUM_PLAYERS, 10)
model.load_state_dict(torch.load(filepath))

def linear_compute_score(model, idxs, absl=False):
    """Score a group of players as the sum of their ``model.bias`` entries.

    Args:
        model: object exposing a ``bias`` callable indexed by a LongTensor.
        idxs: integer indices accepted by ``torch.LongTensor``.
        absl: accepted for signature parity with the other scorers; it is
            not used by this function.

    Returns:
        The summed bias values as a NumPy value.
    """
    members = torch.LongTensor(idxs)
    per_player = model.bias(members)
    return per_player.sum().detach().numpy()
# Select the linear (per-player bias) scorer and restore its saved weights.
# NOTE(review): `score_fn` and `model` are rebound again by the CDM section
# further down, so this binding is overwritten when the whole file runs.
score_fn = linear_compute_score

filepath = "example_models/nba_linear.pth"
model = LR(NUM_PLAYERS)
model.load_state_dict(torch.load(filepath))

def cdm_compute_score(model, idxs, absl=False):
    """Score a group of players with a CDM-style model.

    The rows ``model.one_hot`` returns for ``idxs`` are summed into a single
    float vector. The score is ``model.ws`` of that vector plus the dot
    product of ``model.cs`` and ``model.ts`` applied to the same vector.

    Args:
        model: object exposing ``one_hot``, ``ws``, ``cs`` and ``ts`` callables.
        idxs: integer indices accepted by ``torch.LongTensor``.
        absl: accepted for signature parity with the other scorers; it is
            not used by this function.

    Returns:
        Element 0 of the combined score, as a NumPy scalar.
    """
    members = torch.LongTensor(idxs)
    membership = model.one_hot(members).sum(0).float()

    base = model.ws(membership)
    interaction = torch.dot(model.cs(membership), model.ts(membership))

    return (base + interaction).detach().numpy()[0]
# Select the CDM scorer and restore its saved weights.
# NOTE(review): this is the last section that binds `score_fn` and `model`,
# so when the whole file runs, the analysis below uses the CDM model.
score_fn = cdm_compute_score

filepath = "example_models/nba_cdm.pth"
model = CDM(NUM_PLAYERS, embed_dim=10)
model.load_state_dict(torch.load(filepath))

#########################################################################

'''
Analyze all star team vs random teams
'''

# Baseline distribution: score 1000 randomly drawn five-player teams.
all_idx = np.arange(len(id_ch))
team_scores = []
for _ in range(1000):
    # Shuffle in place and take the first five entries, i.e. five distinct
    # players drawn without replacement.
    np.random.shuffle(all_idx)
    # Copy the slice so the team is not a view into the array that the next
    # shuffle mutates.
    random_team = all_idx[:5].copy()
    team_scores.append(score_fn(model, random_team, absl=False))

# Use a context manager so the output file is flushed and closed
# deterministically rather than whenever the handle is garbage-collected.
with open("random_team1k.npy", "wb") as _scores_file:
    np.save(_scores_file, team_scores)

# Interactive check of the 99.9th percentile; the value is discarded when
# this file runs as a script.
np.percentile(team_scores, 99.9)

def _write_roster_scores(path, rosters, first_year=2013, rosters_per_year=2):
    """Score each roster with the active model and write the results to ``path``.

    Every roster is scored with the module-level ``score_fn`` and ``model``
    (names are resolved through ``get_idx``). One line per roster is written
    as ``year,<player names...>,score`` with no header row and no trailing
    newline. All scores are computed before the file is opened, so a failed
    lookup leaves no partial file behind.

    Args:
        path: output file path; the file is overwritten.
        rosters: sequence of lists of player names.
        first_year: year label assigned to the first roster.
        rosters_per_year: number of consecutive rosters sharing a year label.

    Returns:
        The list of CSV lines that were written.
    """
    lines = []
    for position, players in enumerate(rosters):
        val = score_fn(model, get_idx(players))
        year = first_year + position // rosters_per_year
        lines.append(",".join([str(year), *players, str(val)]))
    with open(path, "w") as handle:
        handle.write("\n".join(lines))
    return lines


# Rosters are listed two per year label, starting at 2013.
nm = "all_star.csv"
out = _write_roster_scores(nm, [
    ["Rajon Rondo", "Dwyane Wade", "LeBron James", "Carmelo Anthony", "Kevin Garnett"],
    ["Chris Paul", "Kobe Bryant", "Kevin Durant", "Blake Griffin", "Dwight Howard"],
    ["Dwyane Wade", "Kyrie Irving", "LeBron James", "Paul George", "Carmelo Anthony"],
    ["Stephen Curry", "Kobe Bryant", "Kevin Durant", "Blake Griffin", "Kevin Love"],
    ["John Wall", "Kyle Lowry", "LeBron James", "Pau Gasol", "Carmelo Anthony"],
    ["Stephen Curry", "Kobe Bryant", "Anthony Davis", "Marc Gasol", "Blake Griffin"],
    ["Dwyane Wade", "Kyle Lowry", "LeBron James", "Paul George", "Carmelo Anthony"],
    ["Stephen Curry", "Russell Westbrook", "Kobe Bryant", "Kevin Durant", "Kawhi Leonard"],
    ["Kyrie Irving", "DeMar DeRozan", "LeBron James", "Jimmy Butler", "Giannis Antetokounmpo"],
    ["Stephen Curry", "James Harden", "Kevin Durant", "Kawhi Leonard", "Anthony Davis"],
    ["Kyrie Irving", "DeMar DeRozan", "LeBron James", "Joel Embiid", "Giannis Antetokounmpo"],
    ["Stephen Curry", "James Harden", "Kevin Durant", "DeMarcus Cousins", "Anthony Davis"],
])


# Comparison rosters, again two per year label starting at 2013.
nm = "all_star_backup.csv"
out = _write_roster_scores(nm, [
    ["Jason Terry", "Ray Allen", "Udonis Haslem", "Kenyon Martin", "Chris Wilcox"],
    ["Eric Bledsoe", "Jodie Meeks", "DeAndre Liggins", "Matt Barnes", "Jordan Hill"],
    ["Ray Allen", "Jarrett Jack", "Udonis Haslem", "Solomon Hill", "Kenyon Martin"],
    ["Steve Blake", "Nick Young", "Thabo Sefolosha", "Glen Davis", "Dante Cunningham"],
    ["Andre Miller", "Lou Williams", "Shawn Marion", "Taj Gibson", "Cleanthony Early"],
    ["Shaun Livingston", "Nick Young", "Ryan Anderson", "Kosta Koufos", "Spencer Hawes"],
    ["Tyler Johnson", "Cory Joseph", "Richard Jefferson", "Solomon Hill", "Derrick Williams"],
    ["Shaun Livingston", "Cameron Payne", "Nick Young", "Josh Huestis", "Kyle Anderson"],
    ["Deron Williams", "Norman Powell", "Richard Jefferson", "Doug McDermott", "Khris Middleton"],
    ["Shaun Livingston", "Tyler Ennis", "Matt Barnes", "Kyle Anderson", "Omer Asik"],
    ["Terry Rozier", "Norman Powell", "Richard Jefferson", "Amir Johnson", "Jabari Parker"],
    ["Shaun Livingston", "Eric Gordon", "Andre Iguodala", "Josh Smith", "Omer Asik"],
])

#######################################################################

'''
Computes Shapley
'''

from core_shapley import *

# The results file holds one match record per blank-line-separated chunk.
text_path = "data/starters_nba_team_result.txt"
with open(text_path) as _results_file:
    match_outcome = _results_file.read().split("\n\n")
# A file ending in a blank line yields an empty final chunk; discard it.
if not match_outcome[-1]:
    match_outcome.pop()

# player index -> list of Shapley values, one per match the player appears in.
player_shaps = defaultdict(list)
for record in match_outcome:
    # The first two lines of a record are skipped; each remaining line is a
    # ";"-separated player row whose first field is the player index and whose
    # last field is the outcome label. Empty rows (e.g. from a trailing newline
    # at end of file) are ignored instead of crashing int("").
    stats = [row for row in record.split("\n")[2:] if row]
    if not stats:
        continue
    fields = [row.split(";") for row in stats]
    all_idxs = np.array([int(parts[0]) for parts in fields])
    # Rows labelled "Win" form the winning side; every other row is treated
    # as the losing side.
    won = np.array([parts[-1] for parts in fields]) == "Win"

    # Each player's Shapley value is computed relative to their own side only.
    for side in (all_idxs[won], all_idxs[~won]):
        shap_vals = [compute_shapley(i, side, model, score_fn) for i in side]
        for player, value in zip(side, shap_vals):
            player_shaps[player].append(value)

##########################################################################

'''
Saves stats
'''

# One row per player that has Shapley values, season stats and salary data.
comparison = []
for idx, shap_vals in player_shaps.items():
    name = id_ch[str(idx)]
    # Membership tests do not insert keys into the defaultdicts.
    if name not in player_stats or name not in salary:
        continue
    team = player_team[name]
    pay = np.average(salary[name])
    # Column-wise mean over the player's stat rows; PER is computed but not exported.
    per, ws, vorp = np.average(player_stats[name], axis=0)
    comparison.append((name, team, ws, vorp, pay, np.average(shap_vals)))

with open("nba_stats.csv", "w") as f:
    # The header lists exactly the six values written per row. The previous
    # header had a seventh "bias" column that no row ever supplied.
    f.write("name, team, win_share, vorp, average_salary, shapley\n")
    f.write("\n".join(",".join(str(x) for x in player) for player in comparison))

